#!/usr/bin/env python3

#
# Run the interpolation test at several resolutions and check that each
# method converges at the expected order
#

from boututils.run_wrapper import shell, shell_safe, launch_safe
from boutdata import collect
from numpy import sqrt, max, abs, mean, array, log, polyfit
from sys import stdout, exit

# Set to True to display each plot on screen in addition to saving it to file
show_plot = False

# Grid sizes (NX values) used for the convergence scan
nxlist = [16, 32, 64, 128]

# Only 2D (x, z) slices are tested, so a single processor is sufficient
nproc = 1

# Names of the variables whose interpolated values are checked
varlist = ['a', 'b', 'c']
# Plot format string for each variable, in the same order as varlist
markers = ['bo', 'r^', 'kx']
# Legend labels: each variable name wrapped in math-mode dollar signs
labels = ['${}$'.format(var) for var in varlist]

# Interpolation method name -> expected order of convergence
methods = dict(
    hermitespline=3,
    lagrange4pt=3,
    bilinear=2,
)



# Build the test executable; the output of make is redirected to make.log
print("Making Interpolation test")
shell_safe("make > make.log")

print("Running Interpolation test")
# Overall result; cleared as soon as any variable fails the convergence check
success = True

for method in methods:
    print("------------------------------")
    print("Using {} interpolation".format(method))

    # Error norms per variable, one entry appended for each value in nxlist
    error_2 = {var: [] for var in varlist}      # L2 error (RMS)
    error_inf = {var: [] for var in varlist}    # Maximum error

    for nx in nxlist:
        dx = 1. / (nx)

        # mesh:nx is passed as nx + 4 -- presumably to leave room for the
        # x guard cells that collect() drops below (TODO confirm)
        args = " mesh:nx={nx4} mesh:dx={dx} MZ={nx} interpolation:type={method}".format(
            nx4=nx + 4, dx=dx, nx=nx, method=method)

        cmd = "./test_interpolate"+args

        # Remove output left over from a previous run. -f keeps rm quiet
        # when there is nothing to delete (e.g. on the very first run).
        shell("rm -f data/BOUT.dmp.*.nc")

        # Only the captured output is needed; the returned status is unused
        _status, out = launch_safe(cmd, nproc=nproc, pipe=True)
        with open("run.log.{}.{}".format(method, nx), "w") as f:
            f.write(out)

        # Collect output data and record the error norms at this resolution
        for var in varlist:
            interp = collect(var+"_interp", path="data", xguards=False, info=False)
            solution = collect(var+"_solution", path="data", xguards=False, info=False)

            E = interp - solution

            l2 = float(sqrt(mean(E**2)))
            linf = float(max(abs(E)))

            error_2[var].append(l2)
            error_inf[var].append(linf)

            print("{0:s} : l-2 {1:.8f} l-inf {2:.8f}".format(var, l2, linf))

    # Grid spacing for every resolution, matching the entries of the error lists
    spacings = 1./array(nxlist)

    for var in varlist:
        # Slope of log(error) against log(dx) is the observed convergence order
        fit = polyfit(log(spacings), log(error_2[var]), 1)
        order = fit[0]
        stdout.write("{0:s} Convergence order = {1:.2f}".format(var, order))

        # Make sure scaling is at least 90% of expected order.
        # Written as ">=" so that a NaN order (every comparison with NaN is
        # False) is reported as a failure instead of slipping through as PASS.
        if order >= 0.9 * methods[method]:
            print("............ PASS")
        else:
            print("............ FAIL")
            success = False

    # Plotting is hard-wired off here; show_plot only has an effect if this
    # condition is changed to enable the branch below.
    if False:
        try:
            import matplotlib.pyplot as plt
            # Plot errors
            plt.figure()

            # Solid line: L2 error, dashed line: maximum error
            for var, mark, label in zip(varlist, markers, labels):
                plt.plot(spacings, error_2[var], '-'+mark, label=label)
                plt.plot(spacings, error_inf[var], '--'+mark)

            plt.legend(loc="upper left")
            plt.grid()

            plt.yscale('log')
            plt.xscale('log')

            plt.xlabel(r'Mesh spacing $\delta x$')
            plt.ylabel("Error norm")
            plt.title("Error scaling for {}".format(method))

            name = "error_scaling_{}.pdf".format(method)
            plt.savefig(name)
            print("Plot saved to {}".format(name))

            if show_plot:
                plt.show()
            plt.close()
        except ImportError:
            print("No matplotlib available")
    else:
        print("Plotting disabled")

# Summarise the outcome and pass it to the caller through the exit status:
# 0 when every convergence check passed, 1 otherwise
print("------------------------------")
if not success:
    print(" => Some failed tests")
    exit(1)

print(" => All Interpolation tests passed")
exit(0)
